# Copyright 2022 Twitter, Inc and Zhendong Wang.
# SPDX-License-Identifier: Apache-2.0

from gobal import debug_params
import argparse
import gym
import numpy as np
import os
import torch
import json

import d4rl
from utils import utils
from utils.data_sampler import Data_Sampler
from utils.logger import logger, setup_logger
from torch.utils.tensorboard import SummaryWriter
from utils.time_utils import RemainingTime

# Per-environment settings, keyed by D4RL environment id.
#   lr           - learning rate handed to the agent
#   eta          - `eta` argument handed to the agent
#   max_q_backup - `max_q_backup` argument handed to the agent
#   reward_tune  - reward transformation name handed to Data_Sampler
#   eval_freq    - number of epochs per training round (between evaluations)
#   num_epochs   - total number of training epochs
#   gn           - value handed to the agent as `grad_norm`
#   top_k        - rank (by BC loss) used by offline model selection
_HP_FIELDS = ('lr', 'eta', 'max_q_backup', 'reward_tune', 'eval_freq', 'num_epochs', 'gn', 'top_k')

# One row per environment: (env id, *values in _HP_FIELDS order).
_HP_ROWS = (
    # Gym locomotion
    ('halfcheetah-medium-v2', 3e-4, 1.0, False, 'no', 50, 2000, 9.0, 1),
    ('hopper-medium-v2', 3e-4, 1.0, False, 'no', 50, 2000, 9.0, 2),
    ('walker2d-medium-v2', 3e-4, 1.0, False, 'no', 50, 2000, 1.0, 1),
    ('halfcheetah-medium-replay-v2', 3e-4, 1.0, False, 'no', 50, 2000, 2.0, 0),
    ('hopper-medium-replay-v2', 3e-4, 1.0, False, 'no', 50, 2000, 4.0, 2),
    ('walker2d-medium-replay-v2', 3e-4, 1.0, False, 'no', 50, 2000, 4.0, 1),
    ('halfcheetah-medium-expert-v2', 3e-4, 1.0, False, 'no', 50, 2000, 7.0, 0),
    ('hopper-medium-expert-v2', 3e-4, 1.0, False, 'no', 50, 2000, 5.0, 2),
    ('walker2d-medium-expert-v2', 3e-4, 1.0, False, 'no', 50, 2000, 5.0, 1),
    # AntMaze
    ('antmaze-umaze-v0', 3e-4, 0.5, False, 'cql_antmaze', 50, 1000, 2.0, 2),
    ('antmaze-umaze-diverse-v0', 3e-4, 2.0, True, 'cql_antmaze', 50, 1000, 3.0, 2),
    ('antmaze-medium-play-v0', 3e-4, 2.0, True, 'cql_antmaze', 50, 1000, 2.0, 1),
    ('antmaze-medium-diverse-v0', 3e-4, 3.0, True, 'cql_antmaze', 50, 1000, 1.0, 1),
    ('antmaze-large-play-v0', 3e-4, 4.5, True, 'cql_antmaze', 50, 1000, 10.0, 2),
    ('antmaze-large-diverse-v0', 3e-4, 3.5, True, 'cql_antmaze', 50, 1000, 7.0, 1),
    # Adroit
    ('pen-human-v1', 3e-5, 0.15, False, 'normalize', 50, 1000, 7.0, 2),
    ('pen-cloned-v1', 3e-5, 0.1, False, 'normalize', 50, 1000, 8.0, 2),
    # Kitchen
    ('kitchen-complete-v0', 3e-4, 0.005, False, 'no', 50, 250, 9.0, 2),
    ('kitchen-partial-v0', 3e-4, 0.005, False, 'no', 50, 1000, 10.0, 2),
    ('kitchen-mixed-v0', 3e-4, 0.005, False, 'no', 50, 1000, 10.0, 0),
)

hyperparameters = {
    env_id: dict(zip(_HP_FIELDS, values))
    for env_id, *values in _HP_ROWS
}


def _select_best_result(scores, ms, top_k):
    """Pick the evaluation row that is reported as the "best" model.

    Args:
        scores: 2-D array with one row per evaluation, laid out as built in
            ``train_agent``: columns 0-3 are raw avg/std and normalized
            avg/std, column 4 is the BC loss and the last column the epoch.
        ms: Model-selection mode, ``'online'`` or ``'offline'``.
        top_k: Zero-based rank (ascending BC loss) selected in offline mode;
            clamped to the number of available rows.

    Returns:
        A JSON-serializable dict describing the selected row, or ``None``
        when ``ms`` is neither ``'online'`` nor ``'offline'``.
    """
    if ms == 'online':
        # Highest normalized score wins.
        best_id = int(np.argmax(scores[:, 2]))
    elif ms == 'offline':
        bc_loss = scores[:, 4]
        rank = min(len(bc_loss) - 1, top_k)
        # argsort(...)[rank] is the index of the row with the rank-th smallest
        # BC loss.  Comparing the argsort output itself against the rank would
        # instead mark the sorted position of row number `rank`.
        best_id = int(np.argsort(bc_loss, kind='stable')[rank])
    else:
        return None

    best = scores[best_id]
    return {'model selection': ms, 'epoch': float(best[-1]),
            'best normalized score avg': float(best[2]),
            'best normalized score std': float(best[3]),
            'best raw score avg': float(best[0]),
            'best raw score std': float(best[1])}


def train_agent(env, state_dim, action_dim, max_action, device, output_dir, args):
    """Train the agent selected by ``args.algo`` and run model selection.

    The D4RL dataset of ``env`` is loaded into a ``Data_Sampler``; training
    then proceeds in rounds of ``args.eval_freq`` epochs.  After a round the
    policy is evaluated (when the round falls inside the evaluation window),
    results are appended to ``<output_dir>/eval.npy`` and, once training ends,
    the selected result is written to ``<output_dir>/best_score_<ms>.txt``.

    Args:
        env: Gym environment whose dataset is read via ``d4rl.qlearning_dataset``.
        state_dim: Observation dimensionality forwarded to the agent.
        action_dim: Action dimensionality forwarded to the agent.
        max_action: Action bound forwarded to the agent.
        device: Device identifier forwarded to the sampler and the agent.
        output_dir: Directory receiving evaluation results and saved models.
        args: Parsed options.  NOTE: for ``algo == 'qgd'`` this function
            overwrites ``args.eval_freq`` and, depending on the preset,
            ``args.batch_size`` and ``args.T``.

    Raises:
        ValueError: If ``args.algo`` is not one of ``'ql'``, ``'qgd'``, ``'bc'``.
    """
    if args.algo not in ('ql', 'qgd', 'bc'):
        raise ValueError(f"Unknown algo {args.algo!r}; expected 'bc', 'ql' or 'qgd'")

    # Load buffer
    dataset = d4rl.qlearning_dataset(env)
    data_sampler = Data_Sampler(dataset, device, args.reward_tune)
    utils.print_banner('Loaded buffer')

    # Defaults valid for every algorithm; the 'qgd' branch overrides them.
    training_epochs = None  # None => evaluate after every training round
    eval_hyperparameters = [None]  # None entry => evaluate the agent as-is

    if args.algo == 'ql':
        from agents.ql_diffusion import DiffusionQL as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      max_q_backup=args.max_q_backup,
                      beta_schedule=args.beta_schedule,
                      n_timesteps=args.T,
                      eta=args.eta,
                      lr=args.lr,
                      lr_decay=args.lr_decay,
                      lr_maxt=args.num_epochs,
                      grad_norm=args.gn)
    elif args.algo == 'qgd':
        from agents.quantileguideddiffusion import QuantileGuidedDiffusion as Agent
        # Hand-edited experiment preset; it selects one of the override blocks
        # below.  {'exp_adv_weight', 'filter', 'quantile', 'expectile', 'conditional'}
        preset = 'conditional'

        # Baseline configuration shared by all presets.
        using_guide = False
        prob_unconditional = 0.1
        guide_weight = 4.
        iql_quantile_point = 0.9 if 'antmaze' in args.env_name else 0.7
        critic_policy = 'current'  # {'current', 'behavior'}
        critic_loss_type = 'original'  # {'original', 'expectile', 'quantile-101'}
        assert (critic_policy in ['current', 'behavior'] and critic_loss_type == 'original') \
               or (critic_policy == 'expectile_behavior' and critic_loss_type in ['expectile', 'quantile-101'])
        filter_policy = 'current'  # {'current', 'behavior', 'same'}
        filter_loss_type = 'quantile-101'  # {'original', 'expectile', 'quantile-101'}
        imitation_method = preset
        balance_condition_reweight = True
        weight_temperature = 1.
        # The next three were previously bound only by the 'conditional'
        # preset, so every other preset failed with UnboundLocalError.
        # NOTE(review): these fallback values are a reviewer's choice - confirm
        # against the agent before running a non-'conditional' preset.
        balance_sample_reweight = False
        n_hard_updating_target_q = None
        condition_pos_embed = False
        eval_hyperparameters = None  # filled in below

        if preset == 'quantile':  # wait adjust
            critic_policy = 'behavior'
            using_guide = True
            critic_loss_type = 'quantile-101'
        elif preset == 'expectile':  # wait adjust
            critic_policy = 'behavior'
            using_guide = True
            critic_loss_type = 'expectile'
            iql_quantile_point = 0.95
        elif preset == 'filter':  # wait adjust
            using_guide = True
            prob_unconditional = 0.1
            guide_weight = 10.
            iql_quantile_point = 0.9 if 'antmaze' in args.env_name else 0.9
            critic_policy = 'current'  # {'current', 'behavior'}
            critic_loss_type = 'original'  # {'original', 'expectile', 'quantile-101'}
            filter_policy = 'current'  # {'current', 'behavior', 'same'}
            filter_loss_type = 'quantile-101'  # {'original', 'expectile', 'quantile-101'}
            imitation_method = 'filter'  # {'exp_adv_weight', 'filter', 'quantile', 'expectile'}
            weight_temperature = 1.
            # Enlarge the batch by 1 / (1 - quantile point).
            args.batch_size = int(args.batch_size / (1 - iql_quantile_point))
            args.T = 5
        elif preset == 'conditional':  # wait adjust
            using_guide = True
            prob_unconditional = 0.1
            guide_weight = 1.  # adjust?
            iql_quantile_point = 0.9 if 'antmaze' in args.env_name else 0.7  # adjust?
            critic_policy = 'behavior'  # {'current', 'behavior'}
            critic_loss_type = 'expectile'  # {'original', 'expectile', 'quantile-101'}
            assert iql_quantile_point < 1. or critic_loss_type == 'original'
            filter_policy = 'behavior'  # {'current', 'behavior', 'same'}
            filter_loss_type = 'quantile-11'  # {'original', 'expectile', 'quantile-101'}
            # {'exp_adv_weight', 'filter', 'quantile', 'ge_quantile', 'quantile_weight', 'expectile'}
            imitation_method = 'quantile'
            # adjust?
            balance_condition_reweight = False  # adjust?
            balance_sample_reweight = False
            condition_pos_embed = True
            n_hard_updating_target_q = None  # adjust?
            weight_temperature = 1.
            args.batch_size = 2560  # int(args.batch_size / (1 - min(0.95, iql_quantile_point)))
            args.T = 5
            # ----
            # [first, last] training-iteration window of each phase.  The
            # 'eval' window is checked below; the remaining entries are
            # handed to the agent.
            # Alternatives tried:
            # training_epochs = None
            # training_epochs = {
            #     'training_critic': [0, 400000],
            #     'training_filter': [0, 400000],
            #     'training_actor': [0, 400000],
            #     'eval': [1, 400000]
            # }
            # training_epochs = {
            #     'training_critic': [0, 2000],
            #     'training_filter': [2000, 4000],
            #     'training_actor':  [4000 + 1, 800000],
            #     'eval': [4000 + 1, 800000]
            # }
            training_epochs = {
                'training_critic': [0, 400000],
                'training_filter': [0, 400000],
                'training_actor': [400000, 800000],
                'eval': [800000, 800000 + 1]
            }
            # ----
            # Sweep of (guide_weight, iql_quantile_point) pairs applied to the
            # agent at evaluation time.
            eval_hyperparameters = [
                (w, q) for w in [1., 3., 0.3, 0.1, 10., 0.03, 30.] for q in [0.5, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 2.0]
            ]

        if eval_hyperparameters is None:
            # Presets without a sweep evaluate with their training-time values.
            eval_hyperparameters = [(guide_weight, iql_quantile_point)]

        print('iql_config', '\n',
              'using_guide', using_guide, '\n',
              'prob_unconditional', prob_unconditional, '\n',
              'guide_weight', guide_weight, '\n',
              'iql_quantile_point', iql_quantile_point, '\n',
              'critic_policy', critic_policy, '\n',
              'critic_loss_type', critic_loss_type, '\n',
              'filter_policy', filter_policy, '\n',
              'filter_loss_type', filter_loss_type, '\n',
              'imitation_method', imitation_method, '\n',
              'balance_condition_reweight', balance_condition_reweight, '\n',
              'balance_sample_reweight', balance_sample_reweight, '\n',
              'n_hard_updating_target_q', n_hard_updating_target_q, '\n',
              'weight_temperature', weight_temperature, '\n',
              'arg.batch_size', args.batch_size, '\n',
              'args.T', args.T, '\n',
              'condition_pos_embed', condition_pos_embed, '\n',
              'training_epochs', training_epochs, '\n',
              )
        args.eval_freq = 100  # debug
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      using_guide=using_guide,
                      prob_unconditional=prob_unconditional,
                      guide_weight=guide_weight,
                      iql_quantile_point=iql_quantile_point,
                      critic_policy=critic_policy,
                      critic_loss_type=critic_loss_type,
                      filter_policy=filter_policy,
                      filter_loss_type=filter_loss_type,
                      imitation_method=imitation_method,
                      balance_condition_reweight=balance_condition_reweight,
                      balance_sample_reweight=balance_sample_reweight,
                      n_hard_updating_target_q=n_hard_updating_target_q,
                      weight_temperature=weight_temperature,
                      max_q_backup=args.max_q_backup,
                      beta_schedule=args.beta_schedule,
                      n_timesteps=args.T,
                      eta=args.eta,
                      lr=args.lr,
                      lr_decay=False,
                      lr_maxt=args.num_epochs,
                      grad_norm=args.gn,
                      condition_pos_embed=condition_pos_embed,
                      training_epochs=training_epochs
                      )
    else:  # args.algo == 'bc' (validated at the top of the function)
        from agents.bc_diffusion import Diffusion_BC as Agent
        agent = Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=args.discount,
                      tau=args.tau,
                      beta_schedule=args.beta_schedule,
                      n_timesteps=args.T,
                      lr=args.lr)

    early_stop = False
    stop_check = utils.EarlyStopping(tolerance=1, min_delta=0.)
    writer = None  # SummaryWriter(output_dir)

    evaluations = []
    training_iters = 0
    max_timesteps = args.num_epochs * args.num_steps_per_epoch
    metric = 100.
    utils.print_banner("Training Start", separator="*", num_star=90)
    rt = RemainingTime(max_timesteps / int(args.eval_freq * args.num_steps_per_epoch), smoothing_coef=0.99)
    while (training_iters < max_timesteps) and (not early_stop):
        rt.start_each_iter()

        # One training round: `eval_freq` epochs of `num_steps_per_epoch` steps.
        for _ in range(int(args.eval_freq)):
            loss_metric = agent.train(data_sampler,
                                      iterations=args.num_steps_per_epoch,
                                      batch_size=args.batch_size,
                                      log_writer=writer)
            training_iters += args.num_steps_per_epoch
            curr_epoch = int(training_iters // int(args.num_steps_per_epoch))

            # Logging
            # NOTE(review): assumes every agent's train() returns all of the
            # keys below - verify for the 'ql' and 'bc' agents.
            utils.print_banner(f"Train step: {training_iters}", separator="*", num_star=90)
            logger.record_tabular('Trained Epochs', curr_epoch)
            logger.record_tabular('Target Q', np.mean(loss_metric['target_q']))
            logger.record_tabular('BC Loss', np.mean(loss_metric['bc_loss']))
            logger.record_tabular('QL Loss', np.mean(loss_metric['ql_loss']))
            logger.record_tabular('Actor Loss', np.mean(loss_metric['actor_loss']))
            logger.record_tabular('Q Critic Loss', np.mean(loss_metric['q_critic_loss']))
            logger.record_tabular('V Critic Loss', np.mean(loss_metric['v_critic_loss']))
            logger.record_tabular('Filter Critic Loss', np.mean(loss_metric['filter_critic_loss']))
            logger.record_tabular('Critic Loss', np.mean(loss_metric['critic_loss']))
            logger.dump_tabular()

        # Evaluation
        in_eval_window = (training_epochs is None
                          or training_epochs['eval'][0] <= training_iters <= training_epochs['eval'][1])
        if in_eval_window:
            bc_loss = np.mean(loss_metric['bc_loss'])
            for param in eval_hyperparameters:
                if param is not None:
                    # Override the guidance settings on the agent for this run.
                    eval_guide_weight, eval_quantile_point = param
                    print('guide_weight', eval_guide_weight, 'iql_quantile_point', eval_quantile_point)
                    agent.actor.guide_weight = eval_guide_weight
                    agent.iql_quantile_point = eval_quantile_point
                eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = \
                    eval_policy(agent, args.env_name, args.seed, eval_episodes=args.eval_episodes)

                evaluations.append([eval_res, eval_res_std, eval_norm_res, eval_norm_res_std,
                                    bc_loss, np.mean(loss_metric['ql_loss']),
                                    np.mean(loss_metric['actor_loss']),
                                    np.mean(loss_metric['q_critic_loss']), np.mean(loss_metric['v_critic_loss']),
                                    np.mean(loss_metric['critic_loss']),
                                    curr_epoch])
                np.save(os.path.join(output_dir, "eval"), evaluations)
                logger.record_tabular('Average Episodic Reward', eval_res)
                logger.record_tabular('Average Episodic N-Reward', eval_norm_res)
                logger.dump_tabular()
                print(training_iters, eval_res, eval_res_std, eval_norm_res, eval_norm_res_std)

            # Per-round bookkeeping runs once, after the sweep: the BC loss
            # does not depend on the evaluation hyperparameters, and repeating
            # the check per sweep entry would compare the loss with itself.
            if args.early_stop:
                early_stop = stop_check(metric, bc_loss)
            metric = bc_loss

            if args.save_best_model:
                agent.save_model(output_dir, curr_epoch)

        rt.end_each_iter()

    # Model Selection: online or offline
    if not evaluations:
        # Happens when training ends before the evaluation window opens.
        utils.print_banner("No evaluation was run; skipping model selection")
        return

    best_res = _select_best_result(np.array(evaluations), args.ms, args.top_k)
    if best_res is not None:
        with open(os.path.join(output_dir, f"best_score_{args.ms}.txt"), 'w') as f:
            f.write(json.dumps(best_res))

    # writer.close()


# Runs policy for X episodes and returns average reward
# A fixed seed is used for the eval environment
def eval_policy(policy, env_name, seed, eval_episodes=10):
    """Roll out ``policy`` in ``eval_episodes`` parallel environments.

    A vectorized environment with ``eval_episodes`` asynchronous copies of
    ``env_name`` is seeded with ``seed + 100`` and stepped until every copy
    has reported ``done`` at least once.  Only rewards collected before a
    copy's first ``done`` count towards its return.

    Args:
        policy: Object exposing ``sample_action(states)`` for a batch of states.
        env_name: Gym environment id.
        seed: Base seed; the evaluation environments use ``seed + 100``.
        eval_episodes: Number of parallel environments (one episode each).

    Returns:
        Tuple ``(avg_reward, std_reward, avg_norm_score, std_norm_score)``,
        where the normalized values come from the environment's
        ``get_normalized_score``.
    """
    eval_env_vector = gym.vector.make(env_name, num_envs=eval_episodes, asynchronous=True)
    eval_env = None
    try:
        eval_env_vector.seed(seed + 100)
        # Single environment used only for score normalization.
        eval_env = gym.make(env_name)

        traj_done = np.zeros([eval_episodes])
        traj_return = np.zeros([eval_episodes])
        state = eval_env_vector.reset()
        i = 0
        while (traj_done > 0).sum() != eval_episodes:
            action = policy.sample_action(np.array(state))
            state, reward, done, _ = eval_env_vector.step(action)
            # Mask out rewards of copies that already finished their episode.
            traj_return = traj_return + reward * (traj_done == 0)
            traj_done = traj_done + done
            i += 1
            if i % 2000 == 0:
                print('iter', i)    # debug
        scores = traj_return.tolist()
        for score in scores:
            print('_', score, eval_env.get_normalized_score(score))
        avg_reward = np.mean(scores)
        std_reward = np.std(scores)

        normalized_scores = [eval_env.get_normalized_score(s) for s in scores]
        avg_norm_score = eval_env.get_normalized_score(avg_reward)
        std_norm_score = np.std(normalized_scores)
    finally:
        # Release the environments (and the vector env's worker processes)
        # even when the rollout fails; this function is called repeatedly.
        eval_env_vector.close()
        if eval_env is not None:
            eval_env.close()

    utils.print_banner(f"Evaluation over {eval_episodes} episodes: {avg_reward:.2f} {avg_norm_score:.2f}")
    return avg_reward, std_reward, avg_norm_score, std_norm_score


# Script entry point: parse options, merge the per-environment settings from
# `hyperparameters`, prepare the results directory and logger, seed everything
# and start training.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ### Experimental Setups ###
    parser.add_argument("--exp", default='exp_1', type=str)  # Experiment ID
    parser.add_argument('--device', default=0, type=int)  # CUDA device index; CPU is used when CUDA is unavailable
    parser.add_argument("--env_name", default="walker2d-medium-expert-v2", type=str)  # OpenAI gym environment name
    parser.add_argument("--dir", default="results", type=str)  # Logging directory
    parser.add_argument("--seed", default=0, type=int)  # Sets Gym, PyTorch and Numpy seeds
    parser.add_argument("--num_steps_per_epoch", default=1000, type=int)

    ### Optimization Setups ###
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--lr_decay", action='store_true')
    parser.add_argument('--early_stop', action='store_true')
    parser.add_argument('--save_best_model', action='store_true')

    ### RL Parameters ###
    parser.add_argument("--discount", default=0.99, type=float)
    parser.add_argument("--tau", default=0.005, type=float)

    ### Diffusion Setting ###
    parser.add_argument("--T", default=5, type=int)
    parser.add_argument("--beta_schedule", default='vp', type=str)
    ### Algo Choice ###
    parser.add_argument("--algo", default="qgd", type=str, help="['bc', 'ql', 'qgd']")
    parser.add_argument("--ms", default='offline', type=str, help="['online', 'offline']")
    # top_k, lr, eta, max_q_backup, reward_tune and gn are not command-line
    # options: they are read from the `hyperparameters` table below.

    args = parser.parse_args()
    if args.env_name not in hyperparameters:
        # Fail with a usage message instead of a bare KeyError further down.
        known_envs = ', '.join(sorted(hyperparameters))
        parser.error(f"unknown --env_name {args.env_name!r}; expected one of: {known_envs}")

    args.device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    args.output_dir = f'{args.dir}'

    env_hparams = hyperparameters[args.env_name]
    args.num_epochs = env_hparams['num_epochs']
    args.eval_freq = env_hparams['eval_freq']
    args.eval_episodes = 10 if 'v2' in args.env_name else 100

    args.lr = env_hparams['lr']
    args.eta = env_hparams['eta']
    args.max_q_backup = env_hparams['max_q_backup']
    args.reward_tune = env_hparams['reward_tune']
    args.gn = env_hparams['gn']
    args.top_k = env_hparams['top_k']

    # Setup Logging
    file_name = f"{args.env_name}|{args.exp}|diffusion-{args.algo}|T-{args.T}"
    if args.lr_decay:
        file_name += '|lr_decay'
    file_name += f'|ms-{args.ms}'

    if args.ms == 'offline':
        file_name += f'|k-{args.top_k}'
    file_name += f'|{args.seed}'

    results_dir = os.path.join(args.output_dir, file_name)
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(results_dir, exist_ok=True)
    utils.print_banner(f"Saving location: {results_dir}")
    # if os.path.exists(os.path.join(results_dir, 'variant.json')):
    #     raise AssertionError("Experiment under this setting has been done!")
    # vars() returns the namespace's own dict, so the updates below also
    # become attributes of `args`.
    variant = vars(args)
    variant.update(version="Diffusion-Policies-RL")

    env = gym.make(args.env_name)

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    variant.update(state_dim=state_dim)
    variant.update(action_dim=action_dim)
    variant.update(max_action=max_action)
    setup_logger(os.path.basename(results_dir), variant=variant, log_dir=results_dir)
    utils.print_banner(f"Env: {args.env_name}, state_dim: {state_dim}, action_dim: {action_dim}")

    train_agent(env,
                state_dim,
                action_dim,
                max_action,
                args.device,
                results_dir,
                args)
